Access parameter recovery results for PH with C model

Author

Milena Musial

Published

January 31, 2024

1 Setup

# Clears user objects only; packages and options from an earlier run persist.
rm(list = ls())
libs <- c("rstan", "gdata", "bayesplot", "stringr", "dplyr", "ggplot2",
          "PerformanceAnalytics")
# library() stops with an error when a package is not installed, whereas
# require() merely returns FALSE and lets the script fail later with a less
# obvious message (e.g. "could not find function").
invisible(lapply(libs, library, character.only = TRUE))
Loading required package: rstan
Loading required package: StanHeaders

rstan version 2.26.22 (Stan version 2.26.1)
For execution on a local, multicore CPU with excess RAM we recommend calling
options(mc.cores = parallel::detectCores()).
To avoid recompilation of unchanged Stan programs, we recommend calling
rstan_options(auto_write = TRUE)
For within-chain threading using `reduce_sum()` or `map_rect()` Stan functions,
change `threads_per_chain` option:
rstan_options(threads_per_chain = 1)
Loading required package: gdata

Attaching package: 'gdata'
The following object is masked from 'package:stats':

    nobs
The following object is masked from 'package:utils':

    object.size
The following object is masked from 'package:base':

    startsWith
Loading required package: bayesplot
This is bayesplot version 1.11.1
- Online documentation and vignettes at mc-stan.org/bayesplot
- bayesplot theme set to bayesplot::theme_default()
   * Does _not_ affect other ggplot2 plots
   * See ?bayesplot_theme_set for details on theme setting
Loading required package: stringr
Loading required package: dplyr

Attaching package: 'dplyr'
The following objects are masked from 'package:gdata':

    combine, first, last, starts_with
The following objects are masked from 'package:stats':

    filter, lag
The following objects are masked from 'package:base':

    intersect, setdiff, setequal, union
Loading required package: ggplot2
Loading required package: PerformanceAnalytics
Loading required package: xts
Loading required package: zoo

Attaching package: 'zoo'
The following objects are masked from 'package:base':

    as.Date, as.Date.numeric

######################### Warning from 'xts' package ##########################
#                                                                             #
# The dplyr lag() function breaks how base R's lag() function is supposed to  #
# work, which breaks lag(my_xts). Calls to lag(my_xts) that you type or       #
# source() into this session won't work correctly.                            #
#                                                                             #
# Use stats::lag() to make sure you're not using dplyr::lag(), or you can add #
# conflictRules('dplyr', exclude = 'lag') to your .Rprofile to stop           #
# dplyr from breaking base R's lag() function.                                #
#                                                                             #
# Code in packages is not affected. It's protected by R's namespace mechanism #
# Set `options(xts.warn_dplyr_breaks_lag = FALSE)` to suppress this warning.  #
#                                                                             #
###############################################################################

Attaching package: 'xts'
The following objects are masked from 'package:dplyr':

    first, last
The following objects are masked from 'package:gdata':

    first, last

Attaching package: 'PerformanceAnalytics'
The following object is masked from 'package:graphics':

    legend
               rstan                gdata            bayesplot 
                TRUE                 TRUE                 TRUE 
             stringr                dplyr              ggplot2 
                TRUE                 TRUE                 TRUE 
PerformanceAnalytics 
                TRUE 
# Project directories on the cluster file system.
datapath <- '/fast/work/groups/ag_schlagenhauf/B01_FP1_WP2/WP2_ILT_CODE/02_Behav_and_Comp_Modeling/'
out_path <- '/fast/work/groups/ag_schlagenhauf/B01_FP1_WP2/WP2_ILT_CODE/02_Behav_and_Comp_Modeling/Output'
behavpath <- '/fast/work/groups/ag_schlagenhauf/B01_FP1_WP2/ILT_DATA'

# Input file names:
#   orig_file     - fit to the real data; its posterior means were the true
#                   parameters used as input for the simulation
#   sim_file      - simulation output file containing y_pred and the
#                   transformed parameters
#   recovery_file - fit of the model to the simulated data
orig_file <- 'fit_n58_2024-04-12_bandit2arm_delta_PH_withC_hierarchical_group_estimation1_delta0.9_stepsize0.5.rds'
sim_file <- 'sim_2024-04-13_bandit2arm_delta_PH_withC_hierarchical_group_sim_n58.rds'
recovery_file <- 'recovery_2024-04-13_bandit2arm_delta_PH_withC_hierarchical_group_n58.rds'

# Each file holds a saved Stan model output object.
orig_fit <- readRDS(file.path(out_path, orig_file))
sim_fit <- readRDS(file.path(out_path, 'Parameter_Recovery', sim_file))
recovery_fit <- readRDS(file.path(out_path, 'Parameter_Recovery', recovery_file))

# Colour palette for all bayesplot figures below.
color_scheme_set("mix-blue-pink")
# Load true parameters

# Posterior mean of each element of Stan parameter `par` in `fit` (unnamed),
# ordered like the columns of as.matrix(fit, pars = par). Taking the means
# from the draws matrix keeps the true values element-aligned with the
# recovered_* draws matrices below, which mcmc_recover_intervals() pairs by
# position; summary() may order the elements of multi-index parameters
# differently from as.matrix().
posterior_mean <- function(fit, par) {
  as.vector(colMeans(as.matrix(fit, pars = par)))
}

## extract posterior means for all parameters to use them as input for simulation

### posterior means of parameters as input for simulation
true_mu <- posterior_mean(orig_fit, "mu")

true_A_sub_m <- posterior_mean(orig_fit, "A_sub_m")
true_tau_sub_m <- posterior_mean(orig_fit, "tau_sub_m")
true_gamma_sub_m <- posterior_mean(orig_fit, "gamma_sub_m")
true_C_sub_m <- posterior_mean(orig_fit, "C_sub_m")

true_A_subj_s <- posterior_mean(orig_fit, "A_subj_s")
true_tau_subj_s <- posterior_mean(orig_fit, "tau_subj_s")
true_gamma_subj_s <- posterior_mean(orig_fit, "gamma_subj_s")
true_C_subj_s <- posterior_mean(orig_fit, "C_subj_s")

true_A_subj_raw <- posterior_mean(orig_fit, "A_subj_raw")
true_tau_subj_raw <- posterior_mean(orig_fit, "tau_subj_raw")
true_gamma_subj_raw <- posterior_mean(orig_fit, "gamma_subj_raw")
true_C_subj_raw <- posterior_mean(orig_fit, "C_subj_raw")

### transformed parameters saved during simulation
sim_posterior <- extract(sim_fit)

# Values from the first saved simulation draw (presumably identical across
# draws of the simulation run -- TODO confirm). as.vector() flattens each 2-D
# slice column-major (first index varies fastest), the same element order as
# the columns of as.matrix() on the recovery fit.
true_A <- as.vector(sim_posterior$A[1, , ])
true_tau <- as.vector(sim_posterior$tau[1, , ])
true_gamma <- as.vector(sim_posterior$gamma[1, , ])
true_C <- as.vector(sim_posterior$C_const[1, , ])

true_mu_A <- as.vector(sim_posterior$mu_A[1])
true_mu_tau <- as.vector(sim_posterior$mu_tau[1])
true_mu_gamma <- as.vector(sim_posterior$mu_gamma[1])
true_mu_C <- as.vector(sim_posterior$mu_C[1])

## extract parameter values based on simulated data
recovered_mu <- as.matrix(recovery_fit, pars = "mu")

recovered_A_sub_m <- as.matrix(recovery_fit, pars = "A_sub_m")
recovered_tau_sub_m <- as.matrix(recovery_fit, pars = "tau_sub_m")
recovered_gamma_sub_m <- as.matrix(recovery_fit, pars = "gamma_sub_m")
recovered_C_sub_m <- as.matrix(recovery_fit, pars = "C_sub_m")

recovered_A_subj_s <- as.matrix(recovery_fit, pars = "A_subj_s")
recovered_tau_subj_s <- as.matrix(recovery_fit, pars = "tau_subj_s")
recovered_gamma_subj_s <- as.matrix(recovery_fit, pars = "gamma_subj_s")
recovered_C_subj_s <- as.matrix(recovery_fit, pars = "C_subj_s")

recovered_A_subj_raw <- as.matrix(recovery_fit, pars = "A_subj_raw")
recovered_tau_subj_raw <- as.matrix(recovery_fit, pars = "tau_subj_raw")
recovered_gamma_subj_raw <- as.matrix(recovery_fit, pars = "gamma_subj_raw")
recovered_C_subj_raw <- as.matrix(recovery_fit, pars = "C_subj_raw")

recovered_A <- as.matrix(recovery_fit, pars = "A")
recovered_tau <- as.matrix(recovery_fit, pars = "tau")
recovered_gamma <- as.matrix(recovery_fit, pars = "gamma")
recovered_C <- as.matrix(recovery_fit, pars = "C_const")

## posterior means of the recovered individual parameters. colMeans() of the
## draws matrices gives the same element order as true_A etc. by construction,
## with no dependence on how summary() orders multi-index parameters.
recovered_A_mean <- colMeans(recovered_A)
recovered_tau_mean <- colMeans(recovered_tau)
recovered_gamma_mean <- colMeans(recovered_gamma)
recovered_C_mean <- colMeans(recovered_C)

# Per-condition subsets ("jui" = condition index 1, "alc" = condition index 2),
# kept for downstream use. Assumes these parameters are 2-D with the condition
# as the last index (column names like "A[12,2]") -- TODO confirm against the
# Stan model.
condition_index <- function(draws) {
  as.integer(sub("^.*,([0-9]+)\\]$", "\\1", colnames(draws)))
}
recovered_A_mean_jui <- recovered_A_mean[condition_index(recovered_A) == 1]
recovered_A_mean_alc <- recovered_A_mean[condition_index(recovered_A) == 2]
recovered_tau_mean_jui <- recovered_tau_mean[condition_index(recovered_tau) == 1]
recovered_tau_mean_alc <- recovered_tau_mean[condition_index(recovered_tau) == 2]
recovered_gamma_mean_jui <- recovered_gamma_mean[condition_index(recovered_gamma) == 1]
recovered_gamma_mean_alc <- recovered_gamma_mean[condition_index(recovered_gamma) == 2]
recovered_C_mean_jui <- recovered_C_mean[condition_index(recovered_C) == 1]
recovered_C_mean_alc <- recovered_C_mean[condition_index(recovered_C) == 2]

recovered_mu_A <- as.matrix(recovery_fit, pars = "mu_A")
recovered_mu_tau <- as.matrix(recovery_fit, pars = "mu_tau")
recovered_mu_gamma <- as.matrix(recovery_fit, pars = "mu_gamma")
recovered_mu_C <- as.matrix(recovery_fit, pars = "mu_C")

2 Recovery plots

## Compare true and recovered parameters

# Recovery plot with 50% inner and 95% outer posterior intervals around each
# recovered parameter, with the true value marked.
#   draws      - draws matrix (iterations x parameters) from as.matrix()
#   true       - true values, same length and order as the columns of `draws`
#   batched    - if TRUE, split the parameters into facets (one per row) of
#                `batch_size` consecutive elements; the last facet holds the
#                remainder, so no manual adjustment is needed when the sample
#                size changes (58 -> 5 x 10 + 8, 116 -> 11 x 10 + 6, ...)
#   batch_size - parameters per facet when batched
plot_recovery <- function(draws, true, batched = FALSE, batch_size = 10) {
  if (!batched) {
    return(mcmc_recover_intervals(draws, true, prob = 0.5, prob_outer = 0.95))
  }
  mcmc_recover_intervals(draws, true, prob = 0.5, prob_outer = 0.95,
                         batch = ceiling(seq_len(length(true)) / batch_size),
                         facet_args = list(ncol = 1))
}

# mu (raw and transformed)
plot_recovery(recovered_mu, true_mu)

plot_recovery(recovered_mu_A, true_mu_A)

plot_recovery(recovered_mu_tau, true_mu_tau)

plot_recovery(recovered_mu_gamma, true_mu_gamma)

plot_recovery(recovered_mu_C, true_mu_C)

# sigma
plot_recovery(recovered_A_subj_s, true_A_subj_s)

plot_recovery(recovered_tau_subj_s, true_tau_subj_s)

plot_recovery(recovered_gamma_subj_s, true_gamma_subj_s)

plot_recovery(recovered_C_subj_s, true_C_subj_s)

# fixed effects
plot_recovery(recovered_A_sub_m, true_A_sub_m)

plot_recovery(recovered_tau_sub_m, true_tau_sub_m)

plot_recovery(recovered_gamma_sub_m, true_gamma_sub_m)

plot_recovery(recovered_C_sub_m, true_C_sub_m)

# individual distances from mu
plot_recovery(recovered_A_subj_raw, true_A_subj_raw, batched = TRUE)

plot_recovery(recovered_tau_subj_raw, true_tau_subj_raw, batched = TRUE)

plot_recovery(recovered_gamma_subj_raw, true_gamma_subj_raw, batched = TRUE)

plot_recovery(recovered_C_subj_raw, true_C_subj_raw, batched = TRUE)

# transformed individual parameters
plot_recovery(recovered_A, true_A, batched = TRUE)

plot_recovery(recovered_tau, true_tau, batched = TRUE)

plot_recovery(recovered_gamma, true_gamma, batched = TRUE)

plot_recovery(recovered_C, true_C, batched = TRUE)

3 Correlation between true and recovered individual parameters

# One row per individual parameter element: true (simulation input) values
# next to their recovered posterior means. data.frame() silently recycles a
# vector whose length divides the longest one (e.g. 58 vs 116), which would
# pair wrong elements, so insist on equal lengths first.
stopifnot(
  "true and recovered individual parameter vectors differ in length" =
    length(unique(lengths(list(true_A, true_tau, true_gamma, true_C,
                               recovered_A_mean, recovered_tau_mean,
                               recovered_gamma_mean, recovered_C_mean)))) == 1L
)
param_df <- data.frame(true_A, true_tau, true_gamma, true_C,
                       recovered_A_mean, recovered_tau_mean,
                       recovered_gamma_mean, recovered_C_mean)
cor(param_df)
                         true_A    true_tau  true_gamma     true_C
true_A                1.0000000  0.16824252  0.12100185 -0.3285231
true_tau              0.1682425  1.00000000  0.38211162 -0.2982736
true_gamma            0.1210018  0.38211162  1.00000000 -0.2332107
true_C               -0.3285231 -0.29827360 -0.23321068  1.0000000
recovered_A_mean      0.6746725  0.27759020  0.29248871 -0.2675338
recovered_tau_mean    0.2465473  0.81971023  0.15075853 -0.1516714
recovered_gamma_mean -0.1795593  0.09689896  0.66473578  0.1756658
recovered_C_mean     -0.3364268 -0.08691413 -0.09336083  0.5833379
                     recovered_A_mean recovered_tau_mean recovered_gamma_mean
true_A                     0.67467248        0.246547317         -0.179559295
true_tau                   0.27759020        0.819710231          0.096898964
true_gamma                 0.29248871        0.150758527          0.664735776
true_C                    -0.26753382       -0.151671411          0.175665789
recovered_A_mean           1.00000000        0.151000787         -0.090037972
recovered_tau_mean         0.15100079        1.000000000         -0.002567954
recovered_gamma_mean      -0.09003797       -0.002567954          1.000000000
recovered_C_mean          -0.19262606       -0.102190484          0.356525304
                     recovered_C_mean
true_A                    -0.33642681
true_tau                  -0.08691413
true_gamma                -0.09336083
true_C                     0.58333787
recovered_A_mean          -0.19262606
recovered_tau_mean        -0.10219048
recovered_gamma_mean       0.35652530
recovered_C_mean           1.00000000
# chart.Correlation() emits the same "argument 1 does not name a graphical
# parameter" warning from an internal par(usr) call once per panel, flooding
# the report. Muffle only that specific warning; any other warning still
# propagates.
withCallingHandlers(
  chart.Correlation(param_df, histogram = TRUE, pch = 19),
  warning = function(w) {
    if (grepl("does not name a graphical parameter", conditionMessage(w),
              fixed = TRUE)) {
      invokeRestart("muffleWarning")
    }
  }
)
Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter